TE EP integration to MoEBlock#3116
Conversation
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…em_reloc gating Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
0ff3bff to
bd14fe6
Compare
| # Minimum per-expert slot alignment fed to ``tex.ep_prepare``. Default 0 | ||
| # uses the natural slot count; set to e.g. 128 to satisfy FP8 grouped-GEMM | ||
| # tile alignment. | ||
| align_size: int = 0 |
There was a problem hiding this comment.
Placeholder comment for me to fix this so align_size is inferred automatically based on the recipe and doesn't need to be specified by the user
| nn.with_logical_partitioning(self.bias_init, ("exp",)), | ||
| (self.num_experts,), | ||
| self.dtype, | ||
| jnp.float32, |
There was a problem hiding this comment.
Is the router always in fp32 so this expert bias must also be? If so, can we add a small comment indicating this
There was a problem hiding this comment.
yes I will add a comment
|
|
||
|
|
||
| __all__ = ["moe", "PermutationBackend"] | ||
| def _with_sharding_constraint_cast_bwd(x: jnp.ndarray, sharding) -> jnp.ndarray: |
There was a problem hiding this comment.
Why do we need this utility function? I haven't seen something like this required for our other VJPs
There was a problem hiding this comment.
for MoE particularly, In our _moe_bwd_rule, d_x is built from two cotangent paths:
d_x_from_dispatch = tex.ep_dispatch_bwd(...) in bf16 if x is in bf16.
and
d_x_from_gate = d_logits_2d @ gate_kernel^T, where d_logits_2d comes from tex.fused_topk_with_score_function_bwd, which only ooutput the logits in fp32 (per alp in one of the weekly meetings + I checked the code for the router kernel). So d_x_from_gate is fp32.
Therefore, d_x = d_x_from_dispatch + d_x_from_gate = bf16 + fp32 = fp32, while our x could be bf16. So the cotangent that flows to bwd needs to be constrainted. Other VJPs don't have this issue because the dgrad will just be the same dtype as the activation dtype.
| # is a frozen dataclass of ints); the rest are jnp.ndarray, | ||
| # GroupedNoScaleTensor (already a pytree), or None when aux_loss_coeff == 0. | ||
| @register_pytree_node_class | ||
| @dataclass |
There was a problem hiding this comment.
I think this tree_flatten was from my patch, but looking at the diff I think it'd be better to use the @flax_struct.dataclass you were using on the permutation dataclasses since that seems to auto-populate a default pytree flatten/unflatten for us
| else: | ||
| d_recv_w_from_intermediate = jnp.zeros_like(recv_w_flat) | ||
|
|
||
| # Activation bwd. Mirror the fwd's fp32 promotion of silu+multiply |
There was a problem hiding this comment.
Why is this dtype casting required? I don't recall us needing it for the non-MoE LNMLP block
There was a problem hiding this comment.
Yes, after confirming with Tim and Przemek, we should kjeep activation input in BF16 and activation output in GEMM's dtype. In latest commit, this should be BF16
| # local expert. We must size to that worst case or NCCL EP's HT kernel | ||
| # rejects the dispatch buffer with ``invalid argument``. | ||
| natural_spe = num_ep * max_tokens_per_rank # = (B // dp_size) * S | ||
| # NCCL EP requires each expert-major output block to be at least |
There was a problem hiding this comment.
Do we have a use-case for user-specified alignments beyond 128 currently? If NCCL EP requires an alignment of at least 128, and since an alignment of 128 is sufficient for all TE grouped GEMM types, would it make sense to instead hardcode _ALIGN_SIZE = 128 as a constant at the top of the file for now to simplify this MoEBlock API.
We can always expand the API to support a user-specified align size in the future
| batch_pspec_axis = (*data_parallelism_axes, ep_axis) | ||
| ep3_spec = P(batch_pspec_axis, None, None) | ||
| ep2_spec = P(batch_pspec_axis, None) | ||
| x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, ep3_spec)) |
There was a problem hiding this comment.
Which axis name inputs are physical mesh axes and why can be logical axes? I see above x = with_sharding_constraint_by_logical_axes(x, input_axes) but here we directly use jax.lax.with_sharding_constraint which only supports mesh axes.
No need to make any changes for now, I just want to assess which are which and then we can discuss if it makes sense to support logical on some/all or if some are required to be physical axes. Thanks!
There was a problem hiding this comment.
added some comments
| # `grad_pre_combine * w` sees them. Padded positions in sparse_probs | ||
| # are already zero (routing_map is False there); only the rare | ||
| # underflow path emits NaN. | ||
| sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, sparse_probs).astype(dtype) |
There was a problem hiding this comment.
Is this NaN filtering a debugging artifact or something we need in the final version?
There was a problem hiding this comment.
debugging artifact. Remopving
…IA#3116) Address jberchtold-nvidia's PR NVIDIA#3116 nit "rename use_bias -> use_ffn_bias and use_expert_bias -> use_expert_routing_bias". The two flags are siblings (they enable two different bias buffers) but the old names suggested ``use_bias`` was the general fallback, which wasn't the intent. The new names make the FFN-vs-routing distinction obvious from the call site. * transformer_engine/jax/flax/moe.py use_bias -> use_ffn_bias (dataclass field + branch in __call__ + docstring entry) use_expert_bias -> use_expert_routing_bias (same) * tests/jax/test_te_ep_moe.py _make_block(use_expert_bias=...) -> use_expert_routing_bias sigmoid-bias-strong config key updated _reference_kwargs_from_config now reads use_expert_routing_bias ``_MoEBlock`` is still the experimental underscore-prefixed alias (no public ``MoEBlock`` export yet), so the rename is API-safe. The pre-resync legacy tests (``test_moe_vjp.py``, ``test_multiprocess_moe_vjp.py``) are intentionally not updated -- they already reference removed APIs like ``PermutationBackend`` and need a separate post-resync cleanup pass. Signed-off-by: tdophung <tdophung@nvidia.com>
…and inline justifications) Responds to jberchtold-nvidia's PR NVIDIA#3116 review threads on ``transformer_engine/jax/moe.py``. All changes are confined to a single file because each review thread targets a localized region and splitting mid-file would risk reordering bugs. Per review thread: 1. "Why do we need _with_sharding_constraint_cast_bwd? I haven't seen something like this required for our other VJPs." -- Expand the helper's docstring to spell out exactly why MoE needs it: unlike LN+MLP, the MoE bwd composes a bf16 cotangent from ep_dispatch_bwd with an fp32 cotangent from fused_topk_with_score_function_bwd (which the fwd's logits_2d -> fp32 promotion forces). Without the cast, ``d_x`` surfaces at fp32 even when ``x`` is bf16, doubling activation grad bandwidth and breaking any downstream LN bwd that pins a bf16 layout. (Review thread "Why do we need this utility function?".) 2. "Why is this dtype casting required? I don't recall us needing it for the non-MoE LNMLP block." -- Expand the comment above the bwd activation fp32 promotion to explain the MoE-specific math: LN+MLP's silu sits behind a downstream LN that absorbs the bf16 rounding error, while MoE's silu sits on the *expert* side of routing -- the bf16 rounding rides directly into expert_outputs and is summed across topk experts by ep_combine. Bf16 silu alone drifts ~1% vs fp32 silu and compounds through wo->combine into the ~1.4% per-element parity gap we measured against the pure-JAX softmax reference. Mirroring the fwd's fp32 promotion in the bwd keeps silu' in lock-step with silu. (Review thread on "# Activation bwd. Mirror the fwd's fp32 promotion of silu+multiply".) 3. "Do we have a use-case for user-specified alignments beyond 128 currently? ... it'd make sense to instead hardcode _ALIGN_SIZE = 128 as a constant at the top of the file for now to simplify this MoEBlock API. We can always expand the API to support a user-specified align size in the future." -- Implement the suggestion. Drop ``align_size`` from ``_moe_fwd_rule`` / ``_moe_bwd_rule`` / ``_moe`` / public ``moe()``; shift the ``custom_vjp`` ``nondiff_argnums`` from ``range(9, 27)`` -> ``range(9, 26)``; replace ``effective_align = max(int(align_size), 128)`` with the new module-level ``_ALIGN_SIZE = 128`` constant. Trim the ``moe()`` docstring accordingly. (Review thread on "natural_spe = num_ep * max_tokens_per_rank".) 4. "Which axis name inputs are physical mesh axes and why can be logical axes? ... No need to make any changes for now, I just want to assess which are which and then we can discuss if it makes sense to support logical on some/all or if some are required to be physical axes." -- Add an "Axis-name parameters" section to ``moe()``'s docstring listing which kwargs are physical mesh axes (``ep_axis``, ``data_parallelism_axes`` -- they index ``Mesh.shape`` directly to compute ``num_ep`` / ``dp_size`` and to construct the ``P((dp..., ep), None, None)`` for ``jax.lax.with_sharding_constraint``) vs logical axes (``input_axes``, ``gate_kernel_axes``, ``wi_kernel_axes``, ``wo_kernel_axes`` -- resolved via the Flax logical-axis rules). Also document why ``ep_axis`` / ``data_parallelism_axes`` are intentionally non-logical: the EP comm-group construction (``dp_color = rank // ep_size``) and the bootstrap signature check both require concrete integer sizes. (Review thread on "batch_pspec_axis = (*data_parallelism_axes, ep_axis)".) 5. "Is this NaN filtering a debugging artifact or something we need in the final version?" -- Strengthen the inline comment above ``sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, ...)`` to explicitly call this out as a CORRECTNESS REQUIREMENT, not a debugging artifact: it covers the sigmoid+K>1 underflow path where top-K sigmoid scores all round to zero and the ``weights / (weights.sum + 1e-20)`` normalisation emits NaN. Observationally the filter is a no-op on the dense unit-test distributions, but it must stay in for sparse / production routing. (Review thread on "sparse_probs = jnp.where(jnp.isnan(sparse_probs), ...).") Not addressed in this commit (intentional): * Review thread on the ``align_size: int = 0`` placeholder in ``flax/moe.py`` ("Placeholder comment for me to fix this so align_size is inferred automatically based on the recipe and doesn't need to be specified by the user"). That's jberchtold's own follow-up. * Review thread on the explicit ``tree_flatten`` / ``tree_unflatten`` on ``_Ctx`` ("better to use the ``@flax_struct.dataclass``"). Deferred to a separate, testable commit because changing a ``custom_vjp`` residual's pytree registration touches subtle ordering / None-handling semantics that warrant their own bisect surface. * Review thread on ``use_bias`` / ``use_expert_bias`` renames -- handled in the immediately preceding commit ``jax/flax,tests: rename use_bias/use_expert_bias for symmetry``. * Review thread on the ``expert_bias`` fp32 init -- already resolved during the Phuong PR NVIDIA#3036 resync (the redundant ``jnp.float32`` second-dtype argument on ``self.param`` was dropped; ``expert_bias`` now lives at ``self.dtype``). Signed-off-by: tdophung <tdophung@nvidia.com>
…aN sanitizer * Rewrite the inline justifications added in 078a7d80 so each one reads as standalone code documentation, not as a reply to a reviewer: drop "per PR NVIDIA#3116 review", "review feedback", "Renamed from ... per PR ..." and similar PR/thread references from moe.py, flax/moe.py, and tests/jax/test_te_ep_moe.py. Technical content (why the fp32 promotion is needed for the MoE silu+multiply, why _with_sharding_constraint_cast_bwd exists, physical-vs-logical axis split in moe() docstring, the 128 alignment rationale) is preserved and reframed to be useful to a reader who has no PR context. * Drop the jnp.where(jnp.isnan(sparse_probs), 0, sparse_probs) guard. Tracing fused_topk_with_score_function.cu shows the kernel divides by sum_scores + 1e-20, so finite non-negative sigmoid scores cannot produce NaN here; the filter was only defense against upstream NaNs, which would mask a real regression if anything ever did start producing them. Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…rce at dispatch Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… static layer registration Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…er + NVTEEpHandle struct (NVTE_EP_HANDLE_CACHE_SIZE=-1 disables eviction) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…CCL_EP Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…hout it Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ogging.h Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…_COPY_{ON,OFF}
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…tyAllSymm Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…CUDA Toolkit) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… for wheel install Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…bmodules Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…rop submodule header mirror Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…al CommWindow Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… with_sharding_constraint Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…trap Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…EpLayerConfig type) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ives (lint 10.00) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…; define NVTE_WITH_NCCL_EP Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ract, drop dead helpers Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Reset 33 local commits onto phuong/ep-3-jax @ c34771d (her latest with EpConfig + EpLayerConfig API, NCCL bumped to 808d2433) and re-applied the three deltas uniquely ours: * transformer_engine/jax/moe.py: replaces upstream's multi-backend MoE block with our TE-EP-only single-custom-vjp rewrite. Adapted to her new API surface: tex.EpLayerConfig replaces tex.ep_make_handle (no more EpHandle pool/cache); 5 EP callsites rewired (cfg passed in place of handle, ep_prepare arg order swapped, top_k= dropped from ep_dispatch_bwd since it's now in cfg. * tests/jax/test_te_ep_moe.py: TE-EP MoE test (kept), with ep_bootstrap kwargs ep_size= and allow_handle_mem_reloc= dropped (no longer supported; ep_size is derived from mesh axes and the handle_mem reloc gating is gone). * tests/jax/run_te_ep_moe.sh: multi-process launcher (kept). Pre-sync state preserved at branch teddy/te_ep_integration.backup-pre-phuong-sync. EOF ) Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci Signed-off-by: tdophung <tdophung@nvidia.com>
…ero) * drop ``TestZZZTeEpMoeBootstrap``: the re-bootstrap mismatch is a one-line guard in ``ep_bootstrap`` and not the MoE block's concern; exercising it from this suite also taints the per-process NCCL bootstrap cache for the rest of the file with no real upside. * drop ``TestTeEpMoEBlockFlax::test_init_apply_parity``: every config in ``_CONFIGS`` already runs ``MoEBlock`` (the Flax wrapper) end-to-end via ``test_forward`` / ``test_backward``, so this was a duplicate of ``softmax`` parity in another wrapper -- leave wrapper refactors to devs without paying for an extra CI run each time. * drop ``sigmoid-bias-zero``: with a zero-init bias buffer the routing math collapses to the no-bias case, so ``sigmoid`` already covers that numerical path. The bias-aware codepath is still exercised by ``sigmoid-bias-strong`` (non-zero bias). * refresh the module-level docstring to list intentional non-coverage so future readers don't re-add these tests. Signed-off-by: tdophung <tdophung@nvidia.com>
… closure)
Two unrelated one-line bugs in the bwd custom_partitioning machinery
that only surface once the MoE block's aux-loss path is lifted out of
shard_map (the custom_partitioning_sharding_rule check is skipped under
shard_map, which is why these never tripped before).
1. FusedMoEAuxLossBwdPrimitive.shardy_sharding_rule:
``grad_aux_loss`` is the cotangent of a scalar loss and is rank-0;
declaring it with a spurious ``grad_one`` factor gave it rank-1 and
tripped JAX's custom_partitioning_sharding_rule rank check at global
view. Change the rule's third operand entry to empty:
"const_buf_one, num_experts, grad_one -> i num_experts"
->
"const_buf_one, num_experts, -> i num_experts"
2. FusedTopkWithScoreFunctionBwdPrimitive.partition:
``del result_infos, routing_map_format`` removed
``routing_map_format`` from the enclosing scope before the nested
``sharded_impl`` closure was invoked. Python closures resolve names
at call time, not definition time, so when XLA finally invoked
``sharded_impl`` for the bwd partitioned impl it raised
``NameError: cannot access free variable 'routing_map_format'``.
Drop ``routing_map_format`` from the ``del`` and leave a NOTE so
future cleanups don't reintroduce the bug. Sibling partition
methods (fwd topk, both aux-loss directions) already only
``del result_infos`` and need no change.
Signed-off-by: tdophung <tdophung@nvidia.com>
A dp_resource or fsdp_resource that exists in the active mesh resource config but is sized 1 in the actual mesh would still be returned by ``_ep_outer_axis()``, pinning EP-output PartitionSpecs to a degenerate axis. JAX collapses size-1 mesh axes during lowering, which made the EP-output specs reference an axis that no longer exists at runtime -- breaking shard_map output stitching on configs where DP or FSDP is optional. Treat a size-1 axis as absent: prefer dp -> fsdp, but only when the candidate axis is actually sized > 1 in the current mesh. Falls back to the previous behaviour when no axis is configured at all. Signed-off-by: tdophung <tdophung@nvidia.com>
After the upstream PR NVIDIA#3036 resync the moe() API surface lost PermutationBackend (TE-EP is the only backend now), gate_inside_vjp (always True), and the per-call quantizer_sets knob (quantization flows through the standard TE autocast / with_quantizer_set context). It also gained apply_topk_weights_early and renamed the wrapper's private _align_size to the public align_size the test suite already uses. The Flax _MoEBlock wrapper was still passing the old kwargs, which broke every test that touched the wrapper. Wrapper changes: * drop "from ..moe import PermutationBackend" plus the dataclass field, the isinstance(..., PermutationBackend) validation in __post_init__, and the pass-through to moe(). * drop "from ..quantize import noop_quantizer_set" and the quantizer_sets=(noop, noop, noop) pass-through. * drop gate_inside_vjp=True. * rename _align_size: int = 0 -> align_size: int = 0 (matches what tests/jax/test_te_ep_moe.py already passes). * add apply_topk_weights_early: bool = False and pass it through to moe(). * refresh class docstring: drop permutation_backend / _align_size / quantizer_sets descriptions, add apply_topk_weights_early / align_size, note that quantization currently flows only through fp8_autocast. Signed-off-by: tdophung <tdophung@nvidia.com>
…ices
Two correctness fixes for the TE-EP MoE custom_vjp that together let
the bwd parity tests pass on 0-token-globally experts, and drop a
workaround that is no longer needed.
(1) Plumb per-expert padded token_counts into grouped_gemm group_sizes.
NCCL EP HT dispatch lays out recv_tokens expert-major as
[expert_0_padded | expert_1_padded | ... | overalloc_tail]
where each per-expert block already includes the
dispatch_output_per_expert_alignment zero-padding and only the trailing
overalloc tail (slack between sum(token_counts) and the worst-case
recv_pr) is unused. Previously _ffn_fwd_per_shard built a static
local_group_sizes = jnp.full((num_local_experts,), slots_per_expert),
which over-counted by the overalloc tail and forced cuBLAS to run the
GEMM for every group including 0-token-routed experts.
Pipe the real per-shard token_counts (1, num_local_experts) from
ep_prepare through _moe_fwd_rule (added to ffn_in_specs/ffn_in_args
with ep2_spec), into _ffn_fwd_per_shard as token_counts_local, and
reshape into local_group_sizes for both grouped_quantize and
grouped_gemm. cuBLAS now skips both 0-token experts and the trailing
overalloc tail. Mirror the residual spec change on the bwd
(local_group_sizes residual moves from P() to ep2_spec).
(2) Per-group jnp.where zero-fill on wgrad outputs.
cuBLAS grouped_gemm skips groups with size_g == 0 without zero-filling
the corresponding out[g, :, :] slice (cublaslt_grouped_gemm.cu lines
2086/2096). For a shard hosting an expert that received zero tokens
globally, d_wo / d_wi_combined for that expert is left uninit, which
propagates NaN straight into the user's optimizer state.
Add wgrad_group_active = (local_group_sizes > 0)[:, None, None] in
_ffn_bwd_per_shard and apply via jnp.where on d_wo (right after the wo
wgrad) and d_wi_combined (right after the fused wi_0+wi_1 wgrad).
Mask shape is (num_local_experts, 1, 1) so cost is negligible.
(3) Drop the lax.cond zero-init guard on r_tok in _moe_fwd_rule._body.
Previously a jax.lax.cond(jnp.any(r_w != 0), identity, zeros_like)
wrapper around recv_tokens worked around tex.ep_dispatch_fwd leaving
the recv buffer uninit on fully-empty-receiver ranks. With (1) in
place, cuBLAS skips experts whose group_sizes == 0 and the per-row
trailing tail of dispatched recv_tokens is unread by every downstream
consumer (subsequent grouped_gemms read only sum(group_sizes) rows;
ep_combine and ep_dispatch_bwd are handle_mem-aware). The only
per-row consumer that would propagate the tail is grouped_dbias
(per-row segment_sum), which only runs when has_bias=True, and that
FFN bias path is currently gated upstream (cuBLAS grouped_gemm has
no fused bias on Hopper yet; PR 3083 adds the pure-JAX bias add).
With (2) handling the user-visible wgrad-NaN risk on 0-token experts,
the lax.cond is now redundant. Replace with a NOTE pointing at the
two follow-ups that would force its reintroduction:
- a future caller that reads the full recv tile non-group-aware
(e.g. an inspect probe), or
- the FFN bias path landing, which would resurrect grouped_dbias.
Also rewrite the _ffn_fwd_per_shard and _ffn_bwd_per_shard docstrings
to spell out the per-row vs per-group uninit semantics so the next
person debugging a NaN here has the invariants written down.
Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci Signed-off-by: tdophung <tdophung@nvidia.com>
…IA#3116) Address jberchtold-nvidia's PR NVIDIA#3116 nit "rename use_bias -> use_ffn_bias and use_expert_bias -> use_expert_routing_bias". The two flags are siblings (they enable two different bias buffers) but the old names suggested ``use_bias`` was the general fallback, which wasn't the intent. The new names make the FFN-vs-routing distinction obvious from the call site. * transformer_engine/jax/flax/moe.py use_bias -> use_ffn_bias (dataclass field + branch in __call__ + docstring entry) use_expert_bias -> use_expert_routing_bias (same) * tests/jax/test_te_ep_moe.py _make_block(use_expert_bias=...) -> use_expert_routing_bias sigmoid-bias-strong config key updated _reference_kwargs_from_config now reads use_expert_routing_bias ``_MoEBlock`` is still the experimental underscore-prefixed alias (no public ``MoEBlock`` export yet), so the rename is API-safe. The pre-resync legacy tests (``test_moe_vjp.py``, ``test_multiprocess_moe_vjp.py``) are intentionally not updated -- they already reference removed APIs like ``PermutationBackend`` and need a separate post-resync cleanup pass. Signed-off-by: tdophung <tdophung@nvidia.com>
…and inline justifications) Responds to jberchtold-nvidia's PR NVIDIA#3116 review threads on ``transformer_engine/jax/moe.py``. All changes are confined to a single file because each review thread targets a localized region and splitting mid-file would risk reordering bugs. Per review thread: 1. "Why do we need _with_sharding_constraint_cast_bwd? I haven't seen something like this required for our other VJPs." -- Expand the helper's docstring to spell out exactly why MoE needs it: unlike LN+MLP, the MoE bwd composes a bf16 cotangent from ep_dispatch_bwd with an fp32 cotangent from fused_topk_with_score_function_bwd (which the fwd's logits_2d -> fp32 promotion forces). Without the cast, ``d_x`` surfaces at fp32 even when ``x`` is bf16, doubling activation grad bandwidth and breaking any downstream LN bwd that pins a bf16 layout. (Review thread "Why do we need this utility function?".) 2. "Why is this dtype casting required? I don't recall us needing it for the non-MoE LNMLP block." -- Expand the comment above the bwd activation fp32 promotion to explain the MoE-specific math: LN+MLP's silu sits behind a downstream LN that absorbs the bf16 rounding error, while MoE's silu sits on the *expert* side of routing -- the bf16 rounding rides directly into expert_outputs and is summed across topk experts by ep_combine. Bf16 silu alone drifts ~1% vs fp32 silu and compounds through wo->combine into the ~1.4% per-element parity gap we measured against the pure-JAX softmax reference. Mirroring the fwd's fp32 promotion in the bwd keeps silu' in lock-step with silu. (Review thread on "# Activation bwd. Mirror the fwd's fp32 promotion of silu+multiply".) 3. "Do we have a use-case for user-specified alignments beyond 128 currently? ... it'd make sense to instead hardcode _ALIGN_SIZE = 128 as a constant at the top of the file for now to simplify this MoEBlock API. We can always expand the API to support a user-specified align size in the future." -- Implement the suggestion. Drop ``align_size`` from ``_moe_fwd_rule`` / ``_moe_bwd_rule`` / ``_moe`` / public ``moe()``; shift the ``custom_vjp`` ``nondiff_argnums`` from ``range(9, 27)`` -> ``range(9, 26)``; replace ``effective_align = max(int(align_size), 128)`` with the new module-level ``_ALIGN_SIZE = 128`` constant. Trim the ``moe()`` docstring accordingly. (Review thread on "natural_spe = num_ep * max_tokens_per_rank".) 4. "Which axis name inputs are physical mesh axes and why can be logical axes? ... No need to make any changes for now, I just want to assess which are which and then we can discuss if it makes sense to support logical on some/all or if some are required to be physical axes." -- Add an "Axis-name parameters" section to ``moe()``'s docstring listing which kwargs are physical mesh axes (``ep_axis``, ``data_parallelism_axes`` -- they index ``Mesh.shape`` directly to compute ``num_ep`` / ``dp_size`` and to construct the ``P((dp..., ep), None, None)`` for ``jax.lax.with_sharding_constraint``) vs logical axes (``input_axes``, ``gate_kernel_axes``, ``wi_kernel_axes``, ``wo_kernel_axes`` -- resolved via the Flax logical-axis rules). Also document why ``ep_axis`` / ``data_parallelism_axes`` are intentionally non-logical: the EP comm-group construction (``dp_color = rank // ep_size``) and the bootstrap signature check both require concrete integer sizes. (Review thread on "batch_pspec_axis = (*data_parallelism_axes, ep_axis)".) 5. "Is this NaN filtering a debugging artifact or something we need in the final version?" -- Strengthen the inline comment above ``sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, ...)`` to explicitly call this out as a CORRECTNESS REQUIREMENT, not a debugging artifact: it covers the sigmoid+K>1 underflow path where top-K sigmoid scores all round to zero and the ``weights / (weights.sum + 1e-20)`` normalisation emits NaN. Observationally the filter is a no-op on the dense unit-test distributions, but it must stay in for sparse / production routing. (Review thread on "sparse_probs = jnp.where(jnp.isnan(sparse_probs), ...).") Not addressed in this commit (intentional): * Review thread on the ``align_size: int = 0`` placeholder in ``flax/moe.py`` ("Placeholder comment for me to fix this so align_size is inferred automatically based on the recipe and doesn't need to be specified by the user"). That's jberchtold's own follow-up. * Review thread on the explicit ``tree_flatten`` / ``tree_unflatten`` on ``_Ctx`` ("better to use the ``@flax_struct.dataclass``"). Deferred to a separate, testable commit because changing a ``custom_vjp`` residual's pytree registration touches subtle ordering / None-handling semantics that warrant their own bisect surface. * Review thread on ``use_bias`` / ``use_expert_bias`` renames -- handled in the immediately preceding commit ``jax/flax,tests: rename use_bias/use_expert_bias for symmetry``. * Review thread on the ``expert_bias`` fp32 init -- already resolved during the Phuong PR NVIDIA#3036 resync (the redundant ``jnp.float32`` second-dtype argument on ``self.param`` was dropped; ``expert_bias`` now lives at ``self.dtype``). Signed-off-by: tdophung <tdophung@nvidia.com>
…aN sanitizer * Rewrite the inline justifications added in 078a7d80 so each one reads as standalone code documentation, not as a reply to a reviewer: drop "per PR NVIDIA#3116 review", "review feedback", "Renamed from ... per PR ..." and similar PR/thread references from moe.py, flax/moe.py, and tests/jax/test_te_ep_moe.py. Technical content (why the fp32 promotion is needed for the MoE silu+multiply, why _with_sharding_constraint_cast_bwd exists, physical-vs-logical axis split in moe() docstring, the 128 alignment rationale) is preserved and reframed to be useful to a reader who has no PR context. * Drop the jnp.where(jnp.isnan(sparse_probs), 0, sparse_probs) guard. Tracing fused_topk_with_score_function.cu shows the kernel divides by sum_scores + 1e-20, so finite non-negative sigmoid scores cannot produce NaN here; the filter was only defense against upstream NaNs, which would mask a real regression if anything ever did start producing them. Signed-off-by: tdophung <tdophung@nvidia.com>
68617ea to
fe44697
Compare
The SwiGLU intermediate (activation inputs gate_proj_out/up_proj_out, silu+multiply, and activation output) was previously promoted to fp32 in _ffn_fwd_per_shard and again in _ffn_bwd_per_shard, then cast back to the wi/wo GEMM dtype. The promotion bought nothing: the activation inputs come out of the wi grouped_gemm in bf16, the activation output is consumed by the wo GEMM (or wo's quantizer for FP8/FP4) in the same dtype, and storing higher precision than either consumer is wasted bandwidth. * _ffn_fwd_per_shard: drop the .astype(jnp.float32) on gate_proj_out and up_proj_out and the trailing .astype(sorted_x.dtype). The multiply now stays in the wi GEMM output dtype end-to-end. * _ffn_bwd_per_shard: symmetric simplification. jax.vjp(act_fn, ...) runs at bf16, both d_intermediate * silu' and d_intermediate * up stay at bf16, no casts. silu' is now consistent with silu (both bf16) so the chain rule composes cleanly without the prior fp32 detour. * tests/jax/test_te_ep_moe.py::_pure_jax_moe_reference: drop the matching fp32 silu in the parity reference so the test compares bf16-vs-bf16. Parity tolerance was not loosened; expect the comparison to tighten now that both sides round silu identically. Also fix an inaccurate inline comment at the apply_topk_weights_early fwd branch: the bf16 requirement on expert_outputs is enforced by ep_bootstrap (which rejects max_token_dtype != bf16 and sizes the NCCL EP HT mega-buffer for 2-byte slots accordingly), not by a runtime assert in the combine FFI. Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Greptile SummaryThis PR wires NVIDIA's NCCL-backed Expert Parallelism (EP) primitives into the TransformerEngine JAX MoE block, replacing the previous PURE_JAX/TRITON permutation backends with a single
Confidence Score: 3/5The EP dispatch/combine forward path may silently produce NaN outputs when recv_topk_weights contains NaN at padded slots (a condition the backward explicitly acknowledges); the bootstrap rank-grouping formula is unvalidated against the mesh layout. Two distinct correctness gaps exist on the hot path: the forward masking step omits NaN sanitization the backward carefully applies to the same tensor, and ep_bootstrap assumes a specific rank ordering never cross-checked against the active JAX mesh. transformer_engine/jax/moe.py (forward NaN masking ~line 827) and transformer_engine/jax/ep.py (bootstrap rank-layout ~line 110) Important Files Changed
Reviews (1): Last reviewed commit: "remove useless comments" | Re-trigger Greptile |
| w = recv_topk_weights[..., None].astype(expert_outputs.dtype) | ||
| mask_bool = (recv_topk_weights != 0)[..., None] | ||
| weighted = jnp.where(mask_bool, expert_outputs * w, jnp.zeros_like(expert_outputs)) | ||
| output = tex.ep_combine_fwd( | ||
| cfg, | ||
| handle_mem, | ||
| weighted, |
There was a problem hiding this comment.
NaN-in-
recv_topk_weights not sanitized in forward
The backward (_moe_bwd_rule) explicitly documents that ep_dispatch_fwd can write NaN into recv_topk_weights at padded slots and carefully sanitizes it before forming mask_bool. The forward uses mask_bool = (recv_topk_weights != 0)[..., None] without sanitization: NaN != 0 evaluates to True in IEEE 754, so padded slots with NaN weights are incorrectly treated as active, and expert_outputs * NaN = NaN propagates into the weighted buffer passed to ep_combine_fwd. If the NCCL EP combine kernel ever touches those padded positions, the output will silently contain NaN. The backward already shows the correct fix: replace with recv_w_clean = jnp.where(jnp.isnan(recv_topk_weights), 0, recv_topk_weights) before building w and mask_bool.
| ep_resource = gsr.ep_resource | ||
| if ep_resource is None: | ||
| raise ValueError( | ||
| "ep_bootstrap requires MeshResource.ep_resource to be set; enter a" | ||
| " global_shard_guard(MeshResource(..., ep_resource=<axis name>)) before bootstrap." | ||
| ) | ||
| ep_size = get_mesh_axis_size(ep_resource) | ||
| outer_axis = _ep_outer_axis() | ||
| if outer_axis is None: | ||
| if world_size != ep_size: | ||
| raise ValueError( | ||
| f"ep_bootstrap: world_size ({world_size}) > ep_size ({ep_size}) but neither" | ||
| " MeshResource.dp_resource nor fsdp_resource is set; name the outer axis so" | ||
| " EP-output tensors can shard across EP groups." | ||
| ) | ||
| num_ep_groups = 1 | ||
| else: |
There was a problem hiding this comment.
Rank-layout assumption in
ep_bootstrap is unvalidated
dp_color = rank // ep_size hard-codes a specific rank-to-group mapping: EP ranks occupy contiguous blocks of ep_size within the global rank space. This is only correct when the JAX mesh assigns its ep_resource axis as the innermost (fastest-varying) dimension. A mesh constructed as Mesh(devices.reshape(ep, dp), ("ep", "dp")) assigns rank 1 to ep=0, dp=1, yet dp_color = 1 // ep_size = 0 incorrectly merges rank 0 and rank 1 into one communicator. There is no guard that checks mesh.device_ids ordering against the formula, so a silently wrong NCCL communicator is built.
| hidden_dim=hidden_dim, | ||
| ) | ||
| ) | ||
|
|
||
|
|
||
| def _default_out_partition_spec(): | ||
| """Leading-axis default: ``(("dp","ep"),)`` if DP/FSDP is set, else ``("ep",)``.""" | ||
| gsr = global_mesh_resource() | ||
| if gsr.ep_resource is None: | ||
| raise ValueError( | ||
| "ep_resource is not set on the active MeshResource; pass out_sharding=... explicitly." | ||
| ) | ||
| outer = _ep_outer_axis() | ||
| leading = (outer, gsr.ep_resource) if outer is not None else gsr.ep_resource | ||
| return (leading,) | ||
|
|
||
|
|
There was a problem hiding this comment.
_dispatch_bwd silently ignores cotangents for handle_mem and token_counts
ep_dispatch returns 4 primals, so g_outputs has 4 elements. Only g_outputs[0] and g_outputs[1] are consumed; the cotangents for handle_mem (uint8) and token_counts (int32) are silently dropped. JAX sets these to zero for non-float outputs so behaviour is correct today, but an explicit assertion on len(g_outputs) would protect against future arity changes.
| // --------------------------------------------------------------------------- | ||
|
|
||
| size_t EPBackend::cache_cap_locked() { | ||
| if (handle_cache_cap_ == 0) { | ||
| const char* cap_env = std::getenv("NVTE_EP_HANDLE_CACHE_SIZE"); | ||
| if (cap_env != nullptr) { | ||
| const int64_t v = static_cast<int64_t>(std::atol(cap_env)); | ||
| if (v < 0) { | ||
| // Unlimited cache. WAR for JAX until XLA fixes handle_mem | ||
| // reloc between runs. | ||
| handle_cache_cap_ = SIZE_MAX; | ||
| } else { | ||
| NVTE_CHECK(v > 0, | ||
| "NVTE_EP_HANDLE_CACHE_SIZE=0 is invalid; use -1 for unlimited or a positive " | ||
| "cap."); | ||
| handle_cache_cap_ = static_cast<size_t>(v); | ||
| } | ||
| } else { | ||
| handle_cache_cap_ = 4096; | ||
| } |
There was a problem hiding this comment.
fallback_layer_cfg_ limits the backend to one unique layer config per process
The fallback path reconstructs handles using a process-wide cached config, implicitly requiring all MoE layers in the process to share the same (top_k, alignment) pair. Users stacking layers with different top_k values will hit a NVTE_CHECK at runtime without a clear message pointing to this constraint. Consider documenting this in ep.h near NVTEEpLayerConfig.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Description
Integrate the new TE EP dispatch + combine APIs in the MoEBlock, in replacement for the previous ragged-all-to-all communications and Triton permutation kernels. Router CUDA kernels + grouped GEMM is the same as before.
Will rebase and squash the commits on this branch once about to merge
Will also change the JAX APIs if needed when TE EP JAX merge
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: